"""This file generates some fake PMT waveforms in a terrible, terrible data format

Please do not use this file as an example of good code, it was written rather quickly...

Jelle Aalbers, 19 September 2016
"""
import os

try:
    from tqdm import tqdm
except ImportError:
    tqdm = lambda x: x

import numpy as np
from scipy import stats


# Pulse scale factors are drawn from Normal(MEAN_AREA, STD_AREA)
MEAN_AREA = 4
STD_AREA = 0.1 * MEAN_AREA
# Standard deviation of the Gaussian baseline noise added to every sample
NOISE_AMPLITUDE = 0.2
# Number of samples in a single waveform
WV_LENGTH = 1000
# The measurement covers times 0 ... TMAX, sampled at N_DATA evenly spaced points
TMAX = 10000
N_DATA = 10000
# Root directory all output files are written to
DATA_DIR = 'data'
# (start, end) of the interval, both inclusive, in which the voltage is drawn
# at random instead of following the nominal voltage(time) curve
BAD_TIME = (500, 600)

def rate_at(voltage):
    """Photon detection rate(voltage) curve: the goal of our analysis.

    The rate is zero for voltages up to 1 and grows as a power law
    (exponent 1.7, prefactor 2e-3) in the excess voltage above 1.
    Accepts a scalar or a numpy array; arrays are handled element-wise.
    """
    excess = voltage - 1
    # Anything below the threshold is clamped to zero excess voltage
    above_threshold = np.clip(excess, 0, np.inf)
    return 2e-3 * above_threshold**1.7


def make_waveforms(voltage, n_wvs=1):
    """Create n_wvs fake PMT waveforms at voltage.

    One continuous trace of n_wvs * WV_LENGTH samples is generated: Gaussian
    noise (standard deviation NOISE_AMPLITUDE) plus a Poisson-distributed
    number of three-sample pulses at uniformly random positions, each scaled
    by a factor drawn from Normal(MEAN_AREA, STD_AREA). Pulses that would
    extend past either end of the trace are truncated.

    Because the trace is only cut into waveforms at the very end, a pulse
    centered on the first or last sample of an interior waveform spills one
    sample into the neighbouring waveform.

    :param voltage: voltage passed to rate_at to get the pulse rate per sample
    :param n_wvs: number of waveforms to generate
    :returns: 1d array of WV_LENGTH samples if n_wvs == 1, otherwise a
        2d array of shape (n_wvs, WV_LENGTH)
    """
    rate = rate_at(voltage)
    tot_t = WV_LENGTH * n_wvs

    # Keep the order of the random draws fixed so seeded runs stay reproducible
    w = np.random.normal(0, NOISE_AMPLITUDE, size=tot_t)
    n_pulses = np.random.poisson(tot_t * rate)
    pulse_centers = np.random.randint(0, tot_t, size=n_pulses)
    pulse_heights = np.random.normal(MEAN_AREA, STD_AREA, size=n_pulses)

    # Add pmt pulses; the middle entry of pulse_shape sits on the pulse center
    pulse_shape = np.array([0.2, 0.5, 0.3])
    half_width = len(pulse_shape) // 2

    for t, h in zip(pulse_centers, pulse_heights):
        start = t - half_width
        # Clip the pulse window to the trace (using the full trace length,
        # not WV_LENGTH) and take the matching part of the pulse shape, so
        # that edge pulses are truncated rather than shifted.
        lo = max(start, 0)
        hi = min(start + len(pulse_shape), tot_t)
        w[lo:hi] += pulse_shape[lo - start:hi - start] * h

    if n_wvs == 1:
        return w
    return w.reshape(-1, WV_LENGTH)


def v_at(t):
    """Voltage(time) curve during the measurement.

    A sinusoid oscillating between 0 and 5, starting at 2.5 for t = 0, with
    a period of 2 * pi * 1000 time units. Accepts a scalar or a numpy array.
    """
    phase = t / 1000
    # Map sin's [-1, 1] range onto [0, 1] before scaling to the full swing
    unit_wave = 0.5 + 0.5 * np.sin(phase)
    return 5 * unit_wave

def main():
    """Generate the fake dataset below DATA_DIR.

    For each of N_DATA evenly spaced times in [0, TMAX] one waveform is
    written to DATA_DIR/part_<n>/<time>.csv, with at most files_per_dir
    files in each part_<n> directory. Afterwards the nominal voltage(time)
    curve is written to DATA_DIR/t_v.csv.
    """
    if not os.path.exists(DATA_DIR):
        os.makedirs(DATA_DIR)

    files_per_dir = 1000
    all_ts = np.linspace(0, TMAX, N_DATA)

    dirname = None
    for index, t in enumerate(tqdm(all_ts)):

        if index % files_per_dir == 0:
            # Current directory is full (or this is the first file):
            # move on to the next part_<n> directory
            dirname = os.path.join(DATA_DIR,
                                   'part_%d' % (index // files_per_dir))
            if not os.path.exists(dirname):
                os.makedirs(dirname)

        if BAD_TIME[0] <= t <= BAD_TIME[1]:
            # During this time, the voltage was actually all over the place
            v = np.random.uniform(0, 10)
        else:
            # Normally, the voltage follows the curve
            v = v_at(t)
        wv = make_waveforms(v)

        # NOTE: the width 4 in '%04f' never has an effect, since '%f' always
        # emits six decimals; names look like '0.000000.csv'. Kept as is so
        # the file names do not change.
        filename = os.path.join(dirname, '%04f.csv' % t)
        np.savetxt(filename, wv)

    # Create the voltage(time) information file. It holds the nominal curve
    # at 100 time points only; the random voltages used during BAD_TIME are
    # not recorded in it.
    t_reduced = np.linspace(0, TMAX, 100)
    np.savetxt(os.path.join(DATA_DIR, 't_v.csv'),
               np.vstack((t_reduced, v_at(t_reduced))).T)


if __name__ == '__main__':
    main()
